def merge_sort(nums: list[int]) -> list[int]:
    """Return a new list with the elements of *nums* in ascending order.

    Top-down recursive merge sort: split the input in half, sort each half
    recursively, then combine the two sorted halves with ``merge``.

    The input list is never modified, and the result is always a fresh
    list -- including for empty or single-element input -- so callers may
    mutate the result without affecting *nums*.

    >>> merge_sort([3, 1, 2])
    [1, 2, 3]
    """
    n = len(nums)
    if n <= 1:
        # Copy rather than return `nums` itself, so the result never
        # aliases the caller's list (longer inputs already get a new list).
        return list(nums)
    mid = n // 2
    left = merge_sort(nums[:mid])
    right = merge_sort(nums[mid:])
    return merge(left, right)

def merge(a: list[int], b: list[int]) -> list[int]:
    """Merge two ascending lists into a single new ascending list.

    Both *a* and *b* must already be sorted in ascending order; neither is
    modified. On ties the element from *a* is emitted first, which keeps the
    merge stable: equal elements keep their relative order, with all of
    *a*'s equal elements before *b*'s. Only ``<`` is used for comparisons.

    >>> merge([1, 3, 5], [2, 3, 4])
    [1, 2, 3, 3, 4, 5]
    """
    merged = []
    i = j = 0
    len_a, len_b = len(a), len(b)
    while i < len_a and j < len_b:
        # Take from b only when it is strictly smaller; ties go to a so
        # that equal elements preserve their original order (stability).
        if b[j] < a[i]:
            merged.append(b[j])
            j += 1
        else:
            merged.append(a[i])
            i += 1
    # At most one of these remainders is non-empty; append it wholesale.
    merged.extend(a[i:])
    merged.extend(b[j:])
    return merged


if __name__ == '__main__':
    # Demo: sort a small sample containing duplicates and print the result.
    sample = [123, 1234, 21, 3, 12, 31, 3, 1, 3]
    sorted_sample = merge_sort(sample)
    print(sorted_sample)
